library(rpart)
library(rattle)
## Warning: package 'rattle' was built under R version 4.2.2
## Loading required package: tibble
## Loading required package: bitops
## Rattle: A free graphical interface for data science with R.
## Versión 5.5.1 Copyright (c) 2006-2021 Togaware Pty Ltd.
## Escriba 'rattle()' para agitar, sacudir y  rotar sus datos.
library(tidyverse)
## ── Attaching packages
## ───────────────────────────────────────
## tidyverse 1.3.2 ──
## ✔ ggplot2 3.3.6      ✔ dplyr   1.0.10
## ✔ tidyr   1.2.1      ✔ stringr 1.4.1 
## ✔ readr   2.1.2      ✔ forcats 0.5.2 
## ✔ purrr   0.3.4      
## ── Conflicts ────────────────────────────────────────── tidyverse_conflicts() ──
## ✖ dplyr::filter() masks stats::filter()
## ✖ dplyr::lag()    masks stats::lag()
library(plotly)
## Warning: package 'plotly' was built under R version 4.2.2
## 
## Attaching package: 'plotly'
## 
## The following object is masked from 'package:ggplot2':
## 
##     last_plot
## 
## The following object is masked from 'package:stats':
## 
##     filter
## 
## The following object is masked from 'package:graphics':
## 
##     layout

Data-generating process

\[ Z = f(x,y) + \epsilon = \sqrt{(x-9)^2 + (y-9)^2} + \epsilon \text{ con } \epsilon \sim \mathcal{N}(0,0.5) \]

# Build a synthetic dataset and plot it in 3D
set.seed(911)

n <- 100

# Draw x, then y, then the noise, so the random stream is consumed in a
# fixed order and the seed above reproduces the same sample.
dtrain <- data.frame(
  x = runif(n, min = 4.5, max = 13.5),
  y = runif(n, min = 4.5, max = 13.5)
)
noise <- rnorm(n, mean = 0, sd = 0.5)

# z is the distance from (x, y) to the point (9, 9), plus Gaussian noise
dtrain$z <- with(dtrain, sqrt((x - 9)^2 + (y - 9)^2)) + noise


# Visualise the synthetic dataset: 3D scatter, then the (x, y) design points
plot_ly(dtrain, x = ~x, y = ~y, z = ~z) %>%
  add_markers(size = 1, color = I("lightblue"))

ggplot(dtrain, aes(x, y)) +
  geom_point(color = "orange") +
  theme_light()

$ Y = f(x) + \epsilon = \sum_{m=1}^{M} c_m \, I(x \in R_m) + \epsilon $

# Fit a shallow regression tree for z using both coordinates as predictors.
# The response must not appear on the right-hand side of the formula: the
# previous `z ~ y + z` made rpart drop z with a warning and ignore x entirely.
tree <- rpart(
  z ~ x + y,
  data = dtrain,
  method = "anova",
  maxdepth = 3, minsplit = 1, minbucket = 1, cp = 0
)
fancyRpartPlot(tree)

fitted.values <- predict(tree, newdata = dtrain)

frame <- tree$frame
# Node numbers of the terminal nodes (rows of the frame marked "<leaf>")
nodevec <- as.numeric(row.names(frame[frame$var == "<leaf>", ]))
# One element per leaf: the split conditions followed from the root to it.
# path.rpart also prints each path; re-knit to refresh that printed output.
path.list <- path.rpart(tree, nodes = nodevec)

# Turn one root-to-leaf path into the rectangle of the (x, y) plane it covers.
#
# path:    character vector as returned by path.rpart, e.g.
#          c("root", "y< 11.74", "x>=6.857").
# x_range: c(min, max) used for x when no split constrains it.
# y_range: c(min, max) used for y when no split constrains it.
#
# Returns a one-row data.frame with columns xmin, xmax, ymin, ymax.
# Cutoffs are parsed from the path text, so they are only as precise as the
# numbers path.rpart formats into that text.
path_to_rect <- function(path, x_range, y_range) {
  bounds <- list(x = x_range, y = y_range)
  for (step in setdiff(path, "root")) {
    parts <- str_split(step, "< |>=")[[1]]
    split_var <- parts[1]
    cutoff <- as.numeric(parts[2])
    if (!split_var %in% names(bounds)) {
      stop("Unexpected split variable in tree path: ", split_var,
           call. = FALSE)
    }
    if (str_detect(step, "< ")) {
      # "var< c": the leaf lies below the cutoff, so tighten the upper bound
      bounds[[split_var]][2] <- cutoff
    } else {
      # "var>=c": the leaf lies at or above the cutoff, tighten the lower bound
      bounds[[split_var]][1] <- cutoff
    }
  }
  data.frame(
    xmin = bounds$x[1], xmax = bounds$x[2],
    ymin = bounds$y[1], ymax = bounds$y[2]
  )
}

# One rectangle per leaf, starting from the observed range of the data.
# bind_rows ignores the list names, so rows are numbered 1..n_leaves.
rect_info <- bind_rows(
  lapply(
    path.list,
    path_to_rect,
    x_range = range(dtrain$x),
    y_range = range(dtrain$y)
  )
)


# Store the tree predictions on the data, rounded to one decimal
dtrain <- mutate(dtrain, fitted.values = round(fitted.values, 1))

# One label position per distinct rounded prediction: the median x and
# median y of the points sharing that value.
label_points <- summarise(
  group_by(dtrain, fitted.values),
  x = median(x),
  y = median(y)
)

# Partition plot: rectangles from rect_info underneath, the observations
# coloured by fitted value on top, and one label per fitted value.
ggplot() +
  geom_rect(
    data = rect_info,
    mapping = aes(xmin = xmin, xmax = xmax, ymin = ymin, ymax = ymax),
    colour = "grey50",
    fill = "white"
  ) +
  geom_point(
    data = dtrain,
    mapping = aes(x = x, y = y, color = fitted.values)
  ) +
  geom_label(
    data = label_points,
    mapping = aes(x = x, y = y, label = fitted.values, color = fitted.values)
  ) +
  labs(color="Valor ajustado") +
  theme_light()